package pir

import (
	"fmt"
	"math"
)

// DBinfo holds the shape and encoding parameters of a database: how many
// entries it has, how entries map onto Z_p elements, the moduli in use, and
// the bookkeeping recorded by Database.Squish for the compressed in-memory form.
type DBinfo struct {
	Num        uint64 // number of DB entries.
	Row_length uint64 // number of bits per DB entry.

	Packing uint64 // number of DB entries per Z_p elem, if log(p) > DB entry size.
	Ne      uint64 // number of Z_p elems per DB entry, if DB entry size > log(p).

	X uint64 // tunable param that governs communication,
	// must be in range [1, Ne] and must be a divisor of Ne;
	// represents the number of times the scheme is repeated.
	// SetupDB initialises it to Ne.
	P    uint64 // plaintext modulus.
	Logq uint64 // base-2 logarithm of the ciphertext modulus (q = 1 << Logq).

	// For in-memory DB compression. All three are zero until Database.Squish
	// fills them in; Database.Unsquish passes them to Matrix.Unsquish.
	Basis     uint64 // first argument given to Matrix.Squish; Squish requires P <= 2^Basis.
	Squishing uint64 // second argument given to Matrix.Squish; Squish requires Logq >= Basis*Squishing.
	Cols      uint64 // value of Data.Cols captured by Squish before compressing.
}

// Database couples the matrix that stores the encoded entries with the
// DBinfo describing how to interpret it.
type Database struct {
	Info DBinfo  // shape and encoding parameters for Data.
	Data *Matrix // encoded contents; GetElem reads entry values from Ne consecutive rows of one column.
}

// Squish compresses DB.Data in place via Matrix.Squish.
//
// It uses the fixed parameters Basis = 10 and Squishing = 3 and records them,
// together with the pre-compression column count, in DB.Info so that Unsquish
// can hand the same values to Matrix.Unsquish.
//
// Panics with "Bad params" if the plaintext modulus P exceeds 2^Basis or if
// Logq is smaller than Basis*Squishing. The check runs before anything is
// modified, so a rejected call leaves both DB.Info and DB.Data untouched.
func (DB *Database) Squish() {
	const (
		basis     uint64 = 10
		squishing uint64 = 3
	)

	// Check that the params allow for this compression before mutating state.
	if DB.Info.P > (1<<basis) || DB.Info.Logq < basis*squishing {
		panic("Bad params")
	}

	DB.Info.Basis = basis
	DB.Info.Squishing = squishing
	DB.Info.Cols = DB.Data.Cols // remembered for Unsquish
	DB.Data.Squish(DB.Info.Basis, DB.Info.Squishing)
}

// Unsquish calls Matrix.Unsquish on DB.Data with the Basis, Squishing and
// Cols values currently stored in DB.Info (the ones Squish records).
func (DB *Database) Unsquish() {
	info := &DB.Info
	DB.Data.Unsquish(info.Basis, info.Squishing, info.Cols)
}

// ReconstructElem turns the raw matrix values of one DB entry into the entry's
// integer value.
//
// Every element of vals is replaced by ((v + P/2) mod 2^Logq) mod P, and the
// resulting values are passed to Reconstruct_from_base_p. When info.Packing is
// non-zero the result is further passed through Base_p with base
// 2^Row_length and position index % Packing (presumably selecting this
// entry's slot inside a packed element).
//
// vals is modified in place. For Logq >= 64 the arithmetic is carried out
// modulo 2^64. Panics if info.P is zero.
func ReconstructElem(vals []uint64, index uint64, info DBinfo) uint64 {
	// 2^Logq does not fit in a uint64 once Logq >= 64: the shift yields 0 and
	// a modulo by it would panic. uint64 addition already wraps mod 2^64, so
	// the explicit reduction is only needed for smaller moduli.
	reduceModQ := info.Logq < 64
	q := uint64(1) << info.Logq
	half := info.P / 2

	for i := range vals {
		v := vals[i] + half
		if reduceModQ {
			v %= q
		}
		vals[i] = v % info.P
	}

	val := Reconstruct_from_base_p(info.P, vals)
	if info.Packing > 0 {
		val = Base_p(1<<info.Row_length, val, index%info.Packing)
	}

	return val
}

// ReconstructElemVec maps the raw matrix values of one DB entry back into
// [0, P): every element of vals is replaced by ((v + P/2) mod 2^Logq) mod P.
//
// The slice is modified in place and the same slice is returned; no copy is
// made. index is accepted for signature parity with ReconstructElem and is
// not used. For Logq >= 64 the arithmetic is carried out modulo 2^64.
// Panics if info.P is zero.
func ReconstructElemVec(vals []uint64, index uint64, info DBinfo) []uint64 {
	// 2^Logq does not fit in a uint64 once Logq >= 64: the shift yields 0 and
	// a modulo by it would panic. uint64 addition already wraps mod 2^64, so
	// the explicit reduction is only needed for smaller moduli.
	reduceModQ := info.Logq < 64
	q := uint64(1) << info.Logq
	half := info.P / 2

	for i := range vals {
		v := vals[i] + half
		if reduceModQ {
			v %= q
		}
		vals[i] = v % info.P
	}

	return vals
}
// GetElem returns the value of the i-th database entry.
//
// The entry's position is i (or i / Packing when Info.Packing is non-zero),
// laid out row-major over DB.Data.Cols columns. Its Ne values are read from
// rows [row*Ne, (row+1)*Ne) of that column and handed to ReconstructElem.
//
// Panics with "Index out of range" if i >= DB.Info.Num.
func (DB *Database) GetElem(i uint64) uint64 {
	if i >= DB.Info.Num {
		panic("Index out of range")
	}

	// With packing, several consecutive entries share one matrix position.
	pos := i
	if DB.Info.Packing > 0 {
		pos = i / DB.Info.Packing
	}
	col := pos % DB.Data.Cols
	row := pos / DB.Data.Cols

	ne := DB.Info.Ne
	vals := make([]uint64, 0, ne)
	for j := row * ne; j < (row+1)*ne; j++ {
		vals = append(vals, DB.Data.Get(j, col))
	}

	return ReconstructElem(vals, i, DB.Info)
}

// GetElemVec returns the Ne per-element values of the i-th database entry
// after they have been passed through ReconstructElemVec.
//
// The entry's position is i (or i / Packing when Info.Packing is non-zero),
// laid out row-major over DB.Data.Cols columns; the values are read from rows
// [row*Ne, (row+1)*Ne) of that column.
//
// Panics with "Index out of range" if i >= DB.Info.Num.
func (DB *Database) GetElemVec(i uint64) []uint64 {
	if i >= DB.Info.Num {
		panic("Index out of range")
	}

	width := DB.Data.Cols
	row, col := i/width, i%width

	if packing := DB.Info.Packing; packing > 0 {
		packed := i / packing
		row, col = packed/width, packed%width
	}

	var vals []uint64
	for j, end := row*DB.Info.Ne, (row+1)*DB.Info.Ne; j < end; j++ {
		vals = append(vals, DB.Data.Get(j, col))
	}
	return ReconstructElemVec(vals, i, DB.Info)
}

// ApproxSquareDatabaseDims picks roughly square dimensions (l, m) for a
// database of N entries of row_length bits each under plaintext modulus p.
//
// l starts at floor(sqrt(db_elems)), where db_elems is the element count
// reported by Num_DB_entries, and is rounded up to the next multiple of the
// number of elements per entry. m is then ceil(db_elems / l).
func ApproxSquareDatabaseDims(N, row_length, p uint64) (uint64, uint64) {
	total, perEntry, _ := Num_DB_entries(N, row_length, p)

	height := uint64(math.Floor(math.Sqrt(float64(total))))
	// Keep the first dimension a whole number of entries.
	if excess := height % perEntry; excess != 0 {
		height = height - excess + perEntry
	}

	width := uint64(math.Ceil(float64(total) / float64(height)))
	return height, width
}
// ApproxDatabaseDimsWithLongText returns dimensions (l, m) in which every
// entry occupies one full column: l is Compute_num_entries_base_p(p,
// row_length) and m is N.
//
// l is by construction already a multiple of the elements-per-entry count, so
// no rounding step is needed. If the helper reports zero elements per entry
// the result is (0, N).
func ApproxDatabaseDimsWithLongText(N, row_length, p uint64) (uint64, uint64) {
	return Compute_num_entries_base_p(p, row_length), N
}

// ApproxDatabaseDims returns dimensions (l, m) whose second component is at
// least lower_bound_m.
//
// The roughly square dimensions from ApproxSquareDatabaseDims are used when
// their m already meets the bound. Otherwise m is set to lower_bound_m and l
// becomes ceil(db_elems / m), rounded up to a multiple of the number of
// elements per entry (both values as reported by Num_DB_entries).
func ApproxDatabaseDims(N, row_length, p, lower_bound_m uint64) (uint64, uint64) {
	if l, m := ApproxSquareDatabaseDims(N, row_length, p); m >= lower_bound_m {
		return l, m
	}

	total, perEntry, _ := Num_DB_entries(N, row_length, p)

	height := uint64(math.Ceil(float64(total) / float64(lower_bound_m)))
	// Keep the first dimension a whole number of entries.
	if excess := height % perEntry; excess != 0 {
		height += perEntry - excess
	}

	return height, lower_bound_m
}

// SetupDB builds a Database whose Info is filled in for Num entries of
// row_length bits each under the parameters p. Data is left nil for the
// caller to populate.
//
// Ne and Packing come from Num_DB_entries; X starts at Ne; Basis, Squishing
// and Cols stay zero. An estimate of the packed size is printed to stdout.
//
// Panics if Num or row_length is zero, if the element count reported by
// Num_DB_entries exceeds p.L*p.M, or if Ne does not divide p.L.
func SetupDB(Num, row_length uint64, p *Params) *Database {
	if Num == 0 || row_length == 0 {
		panic("Empty database!")
	}

	total, perEntry, perElem := Num_DB_entries(Num, row_length, p.P)

	db := &Database{
		Info: DBinfo{
			Num:        Num,
			Row_length: row_length,
			P:          p.P,
			Logq:       p.Logq,
			Ne:         perEntry,
			X:          perEntry,
			Packing:    perElem,
		},
	}

	// Advance X to a divisor of Ne. Since X starts at Ne the condition is
	// already satisfied, so this only matters if the starting value changes.
	for db.Info.Ne%db.Info.X != 0 {
		db.Info.X++
	}

	const bitsPerMB = 1024.0 * 1024.0 * 8.0
	packedBits := float64(p.L*p.M) * math.Log2(float64(p.P))
	fmt.Printf("Total packed DB size is ~%f MB\n", packedBits/bitsPerMB)

	switch {
	case total > p.L*p.M:
		panic("Params and database size don't match")
	case p.L%db.Info.Ne != 0:
		panic("Number of DB elems per entry must divide DB height")
	}

	return db
}

// MakeRandomDB creates a Database via SetupDB and fills Data with the p.L x
// p.M matrix produced by MatrixRand from a fresh GenerateSeed128 seed, then
// applies Sub(p.P / 2) to it.
func MakeRandomDB(Num, row_length uint64, p *Params) *Database {
	db := SetupDB(Num, row_length, p)

	seed := GenerateSeed128()
	contents := MatrixRand(p.L, p.M, p.Logp, 0, seed)
	contents.Sub(p.P / 2)
	db.Data = contents

	return db
}

// MakeDBfromStrVec creates a Database via SetupDB and fills it from vals, one
// string per entry.
//
// Each string is converted with StringToUint64ByBits(elem, p.Logp); the first
// Ne resulting values are written to rows (i/p.M)*Ne .. (i/p.M)*Ne+Ne-1 of
// column i%p.M. Finally Sub(p.P / 2) is applied to the whole matrix.
//
// Panics with "Bad input DB" if len(vals) != Num, and with a descriptive
// message if a string converts to fewer than Ne values.
//
// NOTE(review): entries are always stored one per matrix position, even when
// Info.Packing > 0, whereas GetElem divides the index by Packing in that
// case — confirm this constructor is only used with parameters that yield
// Packing == 0.
func MakeDBfromStrVec(Num, row_length uint64, p *Params, vals []string) *Database {
	D := SetupDB(Num, row_length, p)

	// Reject mismatched input before allocating the L x M matrix.
	if uint64(len(vals)) != Num {
		panic("Bad input DB")
	}
	D.Data = MatrixZeros(p.L, p.M)

	ne := D.Info.Ne
	for i, elem := range vals {
		arr := StringToUint64ByBits(elem, p.Logp)
		if uint64(len(arr)) < ne {
			panic(fmt.Sprintf("DB entry %d encodes to %d Z_p elems, need at least %d", i, len(arr), ne))
		}

		idx := uint64(i)
		firstRow := (idx / p.M) * ne
		col := idx % p.M
		for j := uint64(0); j < ne; j++ {
			D.Data.Set(arr[j], firstRow+j, col)
		}
	}
	D.Data.Sub(p.P / 2)

	return D
}
